#!/usr/bin/env python3
# ex:ts=4:sw=4:sts=4:et
# -*- tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*-
#
# patchtest: execute all unittest test cases discovered for a single patch
#
# Copyright (C) 2016 Intel Corporation
#
# SPDX-License-Identifier: GPL-2.0-only
#

import json
import logging
import os
import sys
import traceback
import unittest

# Include current path so test cases can see it
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))

# Include patchtest library
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), '../meta/lib/patchtest'))

from patchtest_parser import PatchtestParser
from repo import PatchTestRepo

# Module-wide "patchtest" logger: INFO level, one stream handler that emits
# the bare message text with no level/name prefix.
logger = logging.getLogger("patchtest")
logger.setLevel(logging.INFO)

loggerhandler = logging.StreamHandler()
loggerhandler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(loggerhandler)

# Convenience aliases for the two most used logging calls
info, error = logger.info, logger.error

def getResult(patch, mergepatch, logfile=None):
    """Build the unittest result class used for one patchtest pass.

    The returned class closes over the arguments, so it can be handed to
    ``unittest.TextTestRunner(resultclass=...)`` which instantiates it itself.

    Args:
        patch: value passed as the ``patch`` argument of ``PatchTestRepo``.
        mergepatch: when true, ``repo.merge()`` is called once the repo object
            has been created at the start of the test run.
        logfile: optional path; every PASS/FAIL/SKIP line printed to stdout
            is also appended to this file.

    Returns:
        The ``PatchTestResult`` class (not an instance).
    """

    class PatchTestResult(unittest.TextTestResult):
        """ Patchtest TextTestResult """
        shouldStop  = True
        longMessage = False

        success     = 'PASS'
        fail        = 'FAIL'
        skip        = 'SKIP'

        # Substitutions applied, in this order, to the humanised method name
        _description_fixups = (
            ("cve", "CVE"),
            ("signed off by", "Signed-off-by"),
            ("upstream status", "Upstream-Status"),
            ("non auh", "non-AUH"),
            ("presence format", "presence"),
        )

        @classmethod
        def _describe(cls, test):
            """Turn the last component of test.id() into a readable phrase."""
            description = test.id().split('.')[-1].replace('_', ' ')
            for old, new in cls._description_fixups:
                description = description.replace(old, new)
            return description

        @staticmethod
        def _issue(payload):
            """Return the "issue" entry of a JSON payload, or the raw text.

            Falling back to the raw text keeps a message that is not the
            expected JSON object from raising inside the result handler.
            """
            text = str(payload)
            try:
                return json.loads(text)["issue"]
            except (ValueError, KeyError, TypeError):
                return text

        def _report(self, line):
            """Print a result line and append it to logfile when one is set."""
            print(line)
            if logfile:
                with open(logfile, "a") as f:
                    f.write(line + "\n")

        def startTestRun(self):
            # let's create the repo already, it can be used later on
            repoargs = {
                "repodir": PatchtestParser.repodir,
                "commit": PatchtestParser.basecommit,
                "branch": PatchtestParser.basebranch,
                "patch": patch,
            }

            self.repo_error    = False
            self.test_error    = False
            self.test_failure  = False

            try:
                self.repo = PatchtestParser.repo = PatchTestRepo(**repoargs)
            except Exception:
                # format_exc() returns the traceback text; print_exc() returns
                # None, which would be logged as the string "None".
                logger.error(traceback.format_exc())
                self.repo_error = True
                self.stop()
                return

            if mergepatch:
                self.repo.merge()

        def addError(self, test, err):
            self.test_error = True
            # err is the (type, value, traceback) tuple supplied by unittest;
            # no exception is being handled here, so it is formatted explicitly.
            logger.error("".join(traceback.format_exception(*err)).rstrip())

        def addFailure(self, test, err):
            self.test_failure = True
            self._report('{}: {}: {} ({})'.format(
                self.fail, self._describe(test), self._issue(err[1]), test.id()))

        def addSuccess(self, test):
            self._report('{}: {} ({})'.format(
                self.success, self._describe(test), test.id()))

        def addSkip(self, test, reason):
            self._report('{}: {}: {} ({})'.format(
                self.skip, self._describe(test), self._issue(reason), test.id()))

        def stopTestRun(self):

            # in case there was an error on repo object creation, just return
            if self.repo_error:
                return

            self.repo.clean()

    return PatchTestResult

def _runner(resultklass, prefix=None):
    """Discover and run the test cases whose methods start with *prefix*.

    Tests are discovered from ``PatchtestParser.testdir`` using
    ``PatchtestParser.pattern`` and ``PatchtestParser.topdir``.

    Args:
        resultklass: result class handed to ``unittest.TextTestRunner``.
        prefix: optional test-method prefix; when falsy the loader default
            is kept.

    Returns:
        2 when no test case was discovered, 1 when the run raised or the
        result recorded a repo error, a test failure or a test error,
        0 otherwise.
    """
    # load test with the corresponding prefix
    loader = unittest.TestLoader()
    if prefix:
        loader.testMethodPrefix = prefix

    # create the suite with discovered tests and the corresponding runner
    suite = loader.discover(
        start_dir=PatchtestParser.testdir,
        pattern=PatchtestParser.pattern,
        top_level_dir=PatchtestParser.topdir,
    )

    # if there are no test cases, just quit
    if not suite.countTestCases():
        return 2

    runner = unittest.TextTestRunner(resultclass=resultklass, verbosity=0)

    try:
        result = runner.run(suite)
    except Exception:
        # format_exc() returns the traceback text; print_exc() returns None,
        # which would be logged as the string "None".
        logger.error(traceback.format_exc())
        logger.error('patchtest: something went wrong')
        return 1

    # A failed repo setup stops the run before any test executes, so it must
    # not be reported as success. getattr() keeps result classes that do not
    # define repo_error working.
    if getattr(result, 'repo_error', False) or result.test_failure or result.test_error:
        return 1

    return 0

def run(patch, logfile=None):
    """Execute the pre-merge and then the post-merge test passes for *patch*.

    The pre-merge pass runs methods prefixed with 'pretest' without merging
    the patch; the post-merge pass runs methods prefixed with 'test' with the
    patch merged. A summary is printed once both passes are done.

    Returns the pre-merge status when it is non-zero, else the post-merge one.
    """
    # make control-c stop the run through unittest's own handler
    unittest.installHandler()

    # (merge the patch?, test method prefix) for each pass, in execution order
    passes = ((False, 'pretest'), (True, 'test'))

    statuses = []
    for merge, method_prefix in passes:
        result_class = getResult(patch, merge, logfile)
        statuses.append(_runner(result_class, method_prefix))

    pre_status, post_status = statuses
    print_result_message(pre_status, post_status)
    return pre_status or post_status

def print_result_message(preresult, postresult):
    """Print the closing summary for one patch, framed by separator lines.

    Args:
        preresult: status of the pre-merge pass.
        postresult: status of the post-merge pass.

    A status of 2 in both passes is reported as "no test cases found", a
    status of 1 in either pass as a failure, and anything else as success.
    Exactly one of the three messages is emitted.
    """
    print("----------------------------------------------------------------------\n")
    if preresult == 2 and postresult == 2:
        logger.error(
            "patchtest: No test cases found - did you specify the correct suite directory?"
        )
    # elif: without it the "no test cases" case also fell through to the
    # success message below.
    elif preresult == 1 or postresult == 1:
        logger.error(
            "WARNING: patchtest: At least one patchtest caused a failure or an error - please check https://wiki.yoctoproject.org/wiki/Patchtest for further guidance"
        )
    else:
        logger.info("OK: patchtest: All patchtests passed")
    print("----------------------------------------------------------------------\n")

def main():
    """Run patchtest on the patch, or directory of patches, given on the CLI.

    Reads ``patch_path``, ``log_results`` and ``repodir`` from
    ``PatchtestParser``. When ``log_results`` is set, results for each patch
    are appended to ``<patch>.testresult``.

    Returns:
        1 when the target repo has uncommitted changes or a patch file is
        empty, 0 otherwise.
    """
    # Local import: subprocess is not among the file-level imports.
    import subprocess

    patch_path = PatchtestParser.patch_path
    log_results = PatchtestParser.log_results

    # Run git with an argument list instead of a shell string so a repodir
    # containing spaces or shell metacharacters is handled correctly.
    try:
        git_status = subprocess.run(
            ["git", "-C", str(PatchtestParser.repodir), "status"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        ).stdout
    except OSError as exc:
        # Best effort, as before: an unrunnable git yields an empty status.
        logger.error("patchtest: unable to run git status: %s" % exc)
        git_status = ""

    status_matches = ("Changes not staged for commit", "Changes to be committed")
    if any(match in git_status for match in status_matches):
        logger.error("patchtest: there are uncommitted changes in the target repo that would be overwritten. Please commit or restore them before running patchtest")
        return 1

    if os.path.isdir(patch_path):
        patch_list = [os.path.join(patch_path, filename) for filename in sorted(os.listdir(patch_path))]
    else:
        patch_list = [patch_path]

    for patch in patch_list:
        if os.path.getsize(patch) == 0:
            logger.error('patchtest: patch is empty')
            return 1

        logger.info('Testing patch %s' % patch)

        log_path = None
        if log_results:
            log_path = patch + ".testresult"
            with open(log_path, "a") as f:
                f.write("Patchtest results for patch '%s':\n\n" % patch)

        # NOTE(review): the status returned by run() is not propagated to the
        # exit code; this matches the previous behaviour - confirm it is
        # intended.
        run(patch, log_path)

    return 0

if __name__ == '__main__':
    # Populate the PatchtestParser namespace from the command line arguments
    PatchtestParser.set_namespace()

    # switch the logger to debug output when requested
    if PatchtestParser.debug:
        logger.setLevel(logging.DEBUG)

    # without an explicit topdir, use testdir in its place
    if not PatchtestParser.topdir:
        PatchtestParser.topdir = PatchtestParser.testdir

    # Exit status stays 1 unless main() returns and supplies its own value
    exit_status = 1
    try:
        exit_status = main()
    except Exception:
        # show at most five traceback entries
        traceback.print_exc(5)

    sys.exit(exit_status)
